import os
import argparse
import json
from tqdm import tqdm
import torch
import torch.utils.data
from transformers import AutoTokenizer, AutoModel


from modelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator as SparseQuantCalibrator
from modelslim.pytorch.llm_ptq.llm_ptq_tools import QuantConfig as SparseQuantConfig


def parse_args():
    parser = argparse.ArgumentParser(
        description="generate Model weights for quant.")
    parser.add_argument(
        "--fp16_path",
        default=f"{os.environ['PROJECT_PATH']}/resource/sparse/chatglm2",
        help="Location of Model weights, which contains pytorch_model-00001-of-00007.bin and others",
    )
    parser.add_argument(
        "--data_path",
        default=f"{os.environ['PROJECT_PATH']}/resource/CEval",
        help="Location to read and write the quant weights",
    )
    parser.add_argument(
        "--save_path",
        default=f"{os.environ['PROJECT_PATH']}/output/sparse",
        help="Location to read and write the quant weights",
    )
    parsed = parser.parse_args()
    return parsed

args = parse_args()
data_path = args.data_path
fp16_path = args.fp16_path
save_path = args.save_path
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=fp16_path,
                                          trust_remote_code=True, 
                                          local_files_only=True)

model = AutoModel.from_pretrained(pretrained_model_name_or_path=fp16_path,
                                  torch_dtype=torch.float32, 
                                  trust_remote_code=True, 
                                  local_files_only=True)

choices = ["A", "B", "C", "D"]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]

def build_prompt(text):
    return "[Round {}]\n\n问：{}\n\n答：".format(1, text)

extraction_prompt = '综上所述，ABCD中正确的选项是：'

def get_dataset(bs=1):
    with torch.no_grad():
        dataset_all = []          
        entry = data_path + "/val/Other/civil_servant.jsonl"
        dataset_cur = []
        dataset = []
        with open(entry, encoding='utf-8') as file:
            for line in file:
                dataset.append(json.loads(line))
        dataset = dataset[:1]
        correct = 0
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)
        for batch in tqdm(dataloader):
            texts = batch["inputs_pretokenized"]
            queries = [build_prompt(query) for query in texts]
            inputs = tokenizer(queries, padding=True, return_tensors="pt", truncation=True, max_length=2048).to(
                'cpu')
            
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
            intermediate_outputs = []
            for idx in range(len(outputs)):
                output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                response = tokenizer.decode(output)
                print('response: ', response)
                intermediate_outputs.append(response)
            answer_texts = [text + intermediate + "\n" + extraction_prompt for text, intermediate in
                            zip(texts, intermediate_outputs)]
            input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
            inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True,
                               max_length=2048).to('cpu')
            dataset_tmp = [inputs.data['input_ids'], inputs.data['position_ids'], inputs.data['attention_mask'], None, None, None, None, None,
                           None, None, True]
            dataset_cur.append(dataset_tmp)
        dataset_all.extend(dataset_cur)
    return dataset_all

dataset_calib = get_dataset()
print('len of calib_dataset: ', len(dataset_calib))

w_bit = 4
fraction = 0.011
powerquant = False
mm_tensor = False

quant_config = SparseQuantConfig(w_bit=w_bit,
                           disable_names=['transformer.encoder.layers.0.self_attention.query_key_value',
                                   'transformer.encoder.layers.0.self_attention.dense',
                                   'transformer.encoder.layers.0.mlp.dense_h_to_4h',
                                    'transformer.encoder.layers.0.mlp.dense_4h_to_h', 
                                    'transformer.output_layer'],
                           dev_type='cpu',
                           act_method=3,
                           pr=2.0,
                           fraction=fraction,
                           nonuniform=powerquant,
                           mm_tensor=mm_tensor,
                           co_sparse=True)

calibrator = SparseQuantCalibrator(model, quant_config, calib_data=dataset_calib, disable_level='L0')  # 内部回退两层
calibrator.run(int_infer=False)
calibrator.save(save_path) #存储量化参数用于部署，在存储量化参数过程中，存在反序列化风险，已通过将保存的量化结果文件夹权限设置为750，将量化结果文件权限设置为400来消减该风险